{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hugging Face Babelscape/rebel-large model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install python-dotenv\n",
    "%pip install datasets\n",
    "%pip install langchain\n",
    "# Required by `from langchain_community.document_loaders import PyPDFLoader` further down.\n",
    "%pip install langchain-community pypdf\n",
    "%pip install transformers\n",
    "%pip install neo4j\n",
    "%pip install -U huggingface_hub\n",
    "%pip install torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import json\n",
    "import hashlib\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import PyPDFLoader\n",
    "from datetime import datetime\n",
    "\n",
    "loader = PyPDFLoader('../data/Football_news.pdf')\n",
    "# The clock starts before the PDF is read; the last cell uses it to report processing time.\n",
    "start_time = datetime.now()\n",
    "pages = loader.load_and_split()\n",
    "\n",
    "# Flatten each page to a single line and concatenate the pages, putting one space\n",
    "# in front of every page (so the text also begins with a space).\n",
    "contents = ''.join(\n",
    "    ' ' + page.page_content.replace('\\n', ' ')\n",
    "    for page in pages\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cut the document text into small overlapping chunks with langchain's\n",
    "# RecursiveCharacterTextSplitter. Sizes are measured with the built-in len,\n",
    "# i.e. in characters.\n",
    "text_splitter = RecursiveCharacterTextSplitter(\n",
    "    is_separator_regex=False,\n",
    "    length_function=len,\n",
    "    chunk_overlap=20,\n",
    "    chunk_size=200,\n",
    ")\n",
    "\n",
    "chunks = text_splitter.split_text(contents)\n",
    "#chunk_documents = text_splitter.create_documents([contents])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [],
   "source": [
    "#to add id to the chunks\n",
    "# m = hashlib.md5()\n",
    "# tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "# data = []\n",
    "# m = hashlib.md5()\n",
    "# m.update(contents.encode('utf-8'))\n",
    "# uid = m.hexdigest()[:12]\n",
    "\n",
    "# for i, chunk in enumerate(chunks):\n",
    "#         data.append({\n",
    "#             'id': f'{uid}-{i}',\n",
    "#             'text': chunk\n",
    "#         })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_triplets(text):\n",
    "    \"\"\"Parse a decoded REBEL generation into a list of triplets.\n",
    "\n",
    "    The text is read as a stream of whitespace-separated tokens of the form\n",
    "\n",
    "        <triplet> head ... <subj> tail ... <obj> relation ...\n",
    "\n",
    "    Tokens after ``<triplet>`` build the head entity, tokens after ``<subj>``\n",
    "    build the tail entity and tokens after ``<obj>`` build the relation name.\n",
    "    Several ``<subj> ... <obj> ...`` groups may follow one ``<triplet>``; they\n",
    "    all share its head. ``<s>``, ``<pad>`` and ``</s>`` are removed before\n",
    "    parsing, and tokens seen before the first marker are ignored.\n",
    "\n",
    "    Args:\n",
    "        text (str): decoded model output with the marker tokens still present.\n",
    "\n",
    "    Returns:\n",
    "        list[dict]: one ``{'head': ..., 'type': ..., 'tail': ...}`` dict per\n",
    "        triplet, in order of appearance (empty list if none was found).\n",
    "    \"\"\"\n",
    "    triplets = []\n",
    "    subject, relation, object_ = '', '', ''\n",
    "    text = text.strip()\n",
    "    current = 'x'  # no marker seen yet, so plain tokens are dropped\n",
    "    for token in text.replace(\"<s>\", \"\").replace(\"<pad>\", \"\").replace(\"</s>\", \"\").split():\n",
    "        if token == \"<triplet>\":\n",
    "            # A new head starts: flush whatever was collected for the previous one.\n",
    "            current = 't'\n",
    "            if relation != '':\n",
    "                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})\n",
    "                relation = ''\n",
    "            subject = ''\n",
    "        elif token == \"<subj>\":\n",
    "            # Another tail for the same head: flush the previous (tail, relation) pair first.\n",
    "            current = 's'\n",
    "            if relation != '':\n",
    "                triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})\n",
    "            object_ = ''\n",
    "        elif token == \"<obj>\":\n",
    "            current = 'o'\n",
    "            relation = ''\n",
    "        elif current == 't':\n",
    "            subject += ' ' + token\n",
    "        elif current == 's':\n",
    "            object_ += ' ' + token\n",
    "        elif current == 'o':\n",
    "            relation += ' ' + token\n",
    "    # The last triplet has no marker after it; keep it only when it is complete.\n",
    "    if subject != '' and relation != '' and object_ != '':\n",
    "        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})\n",
    "    return triplets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = \"Babelscape/rebel-large\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# Keyword arguments that generate_triples() forwards to model.generate().\n",
    "gen_kwargs = dict(\n",
    "    max_length=256,\n",
    "    length_penalty=0,\n",
    "    num_beams=3,\n",
    "    num_return_sequences=1,\n",
    ")\n",
    "\n",
    "# Accumulator that generate_triples() appends (head, type, tail) tuples to.\n",
    "triples = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_triples(texts):\n",
    "    \"\"\"Run the model on a batch of texts and record the triples it produces.\n",
    "\n",
    "    The texts are tokenized together (padded, truncated to 512 tokens),\n",
    "    generated with the settings in ``gen_kwargs`` and parsed with\n",
    "    ``extract_triplets``. Each parsed triplet is appended to the global\n",
    "    ``triples`` list as a ``(head, type, tail)`` tuple; nothing is returned.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors='pt')\n",
    "    device = model.device\n",
    "    output_ids = model.generate(\n",
    "        encoded[\"input_ids\"].to(device),\n",
    "        attention_mask=encoded[\"attention_mask\"].to(device),\n",
    "        **gen_kwargs\n",
    "    )\n",
    "    # Special tokens are kept: extract_triplets keys off the <triplet>/<subj>/<obj> markers.\n",
    "    for decoded in tokenizer.batch_decode(output_ids, skip_special_tokens=False):\n",
    "        triples.extend(\n",
    "            (triplet['head'], triplet['type'], triplet['tail'])\n",
    "            for triplet in extract_triplets(decoded)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 2  # number of chunks sent to the model per call\n",
    "\n",
    "# Slicing never raises past the end of the list, so a shorter final batch\n",
    "# needs no exception handling.\n",
    "for start in tqdm(range(0, len(chunks), BATCH_SIZE)):\n",
    "    generate_triples(chunks[start:start + BATCH_SIZE])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dict.fromkeys drops duplicates while keeping first-seen order, so the list is\n",
    "# identical on every run (the iteration order of a set of string tuples is not).\n",
    "distinct_triples = list(dict.fromkeys(triples))\n",
    "distinct_triples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [],
   "source": [
    "import neo4j\n",
    "from neo4j import GraphDatabase\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "# Connection settings are read from the environment (load_dotenv() above may fill it from a .env file).\n",
    "URI = os.environ['NEO4J_URI']\n",
    "USER = os.environ['NEO4J_USERNAME']\n",
    "PASSWORD = os.environ['NEO4J_PASSWORD']\n",
    "AUTH = (USER, PASSWORD)\n",
    "\n",
    "# One statement per triplet: each MERGE creates its node / relationship only when it is missing.\n",
    "MERGE_TRIPLET_QUERY = (\n",
    "    \"MERGE (s:Node {name: $source}) \"\n",
    "    \"MERGE (t:Node {name: $target}) \"\n",
    "    \"MERGE (s)-[r:Relation {name: $relation}]->(t)\"\n",
    ")\n",
    "\n",
    "\n",
    "def create_triplets(tx, triplet):\n",
    "    \"\"\"Merge one triplet into the graph inside the transaction ``tx``.\n",
    "\n",
    "    ``triplet`` is indexed as (source name, relation name, target name).\n",
    "    \"\"\"\n",
    "    tx.run(MERGE_TRIPLET_QUERY, source=triplet[0], relation=triplet[1], target=triplet[2])\n",
    "\n",
    "\n",
    "# The context managers close the session and the driver even if a write fails.\n",
    "with GraphDatabase.driver(URI, auth=AUTH) as driver:\n",
    "    # verify_authentication() reports rejected credentials through its return value.\n",
    "    if not driver.verify_authentication():\n",
    "        raise RuntimeError(\"Neo4j rejected the credentials in NEO4J_USERNAME / NEO4J_PASSWORD\")\n",
    "    with driver.session() as session:\n",
    "        for triplet in distinct_triples:\n",
    "            session.execute_write(create_triplets, triplet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "end_time = datetime.now()\n",
    "\n",
    "LLM = \"Rebel-large\"\n",
    "file = loader.file_path\n",
    "processed_time = str(end_time - start_time)\n",
    "# Heads and tails (positions 0 and 2 of each triplet), de-duplicated in first-seen\n",
    "# order so the report does not change between runs on the same triples.\n",
    "distinct_nodes = list(dict.fromkeys(node for triplet in distinct_triples for node in triplet[::2]))\n",
    "node_count = len(distinct_nodes)\n",
    "relation_count = len(distinct_triples)\n",
    "nodes = distinct_nodes\n",
    "relations = [triplet[1] for triplet in distinct_triples]\n",
    "\n",
    "llm_data = {\n",
    "    \"LLM\": LLM,\n",
    "    \"File\": file,\n",
    "    \"Processing Time\": processed_time,\n",
    "    \"Node count\": node_count,\n",
    "    \"Relation count\": relation_count,\n",
    "    \"Nodes\": distinct_nodes,\n",
    "    \"Relations\": relations,\n",
    "}\n",
    "\n",
    "json_file_path = '../data/llm_comparision.json'\n",
    "\n",
    "# Start an empty report when the comparison file does not exist yet.\n",
    "try:\n",
    "    with open(json_file_path, 'r') as json_file:\n",
    "        data = json.load(json_file)\n",
    "except FileNotFoundError:\n",
    "    data = []\n",
    "\n",
    "# A file that holds a single run as a bare object is wrapped into a list first.\n",
    "if isinstance(data, dict):\n",
    "    data = [data]\n",
    "\n",
    "data.append(llm_data)\n",
    "\n",
    "with open(json_file_path, 'w') as json_file:\n",
    "    json.dump(data, json_file, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Observations:\n",
    "\n",
    "- The graph does not have any properties other than `name`.\n",
    "- Node count = 23\n",
    "- Relation count = 16"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
